{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multidoc Autoretrieval Pack\n",
    "\n",
    "This is the LlamaPack version of our structured hierarchical retrieval guide in the [core repo](https://docs.llamaindex.ai/en/stable/examples/query_engine/multi_doc_auto_retrieval/multi_doc_auto_retrieval.html)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup and Download Data\n",
    "\n",
    "In this section, we'll load LlamaIndex GitHub issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-readers-github\n",
    "%pip install llama-index-vector-stores-weaviate\n",
    "%pip install llama-index-llms-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "# Reuse a token that is already exported; otherwise prompt for one so the\n",
    "# secret is never stored in the notebook and an existing token is never\n",
    "# overwritten with an empty string.\n",
    "if not os.environ.get(\"GITHUB_TOKEN\"):\n",
    "    os.environ[\"GITHUB_TOKEN\"] = getpass(\"GitHub personal access token: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from llama_index.readers.github import GitHubRepositoryIssuesReader, GitHubIssuesClient\n",
    "\n",
    "github_client = GitHubIssuesClient()\n",
    "loader = GitHubRepositoryIssuesReader(\n",
    "    github_client,\n",
    "    owner=\"run-llama\",\n",
    "    repo=\"llama_index\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "orig_docs = loader.load_data()\n",
    "\n",
    "# Only the first `limit` issues are kept, which keeps the demo small.\n",
    "limit = 100\n",
    "\n",
    "docs = orig_docs[:limit]\n",
    "for doc in docs:\n",
    "    # Tag only the issues that are kept; the same `index_id` key is written onto\n",
    "    # the summary documents built later in this notebook.\n",
    "    doc.metadata[\"index_id\"] = doc.id_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import asyncio\n",
    "from tqdm.asyncio import tqdm_asyncio\n",
    "from llama_index.core.indices import SummaryIndex\n",
    "from llama_index.core import Document, ServiceContext\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core.async_utils import run_jobs\n",
    "\n",
    "\n",
    "async def aprocess_doc(doc, include_summary: bool = True):\n",
    "    \"\"\"Build a summary document for a single GitHub issue.\n",
    "\n",
    "    The issue metadata is reduced to a flat set of fields and an LLM-generated\n",
    "    one-sentence summary becomes the text of the returned document.\n",
    "\n",
    "    Args:\n",
    "        doc: Issue document. Its metadata must contain `created_at` (a string\n",
    "            starting with `YYYY-MM-DD`) and `state`; `assignee` and `labels`\n",
    "            are optional.\n",
    "        include_summary: Currently unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        A new `Document` whose text is the summary and whose metadata holds\n",
    "        `state`, `year`, `month`, `day`, `assignee`, `size` and `index_id`.\n",
    "    \"\"\"\n",
    "    print(f\"Processing {doc.id_}\")\n",
    "    metadata = doc.metadata\n",
    "\n",
    "    # Keep only the date part in front of the `T` separator.\n",
    "    date_tokens = metadata[\"created_at\"].split(\"T\")[0].split(\"-\")\n",
    "    year, month, day = (int(token) for token in date_tokens[:3])\n",
    "\n",
    "    assignee = metadata.get(\"assignee\", \"\")\n",
    "    # First label containing `size:` wins; empty string when there is none.\n",
    "    size = next(\n",
    "        (\n",
    "            label.split(\":\")[1]\n",
    "            for label in metadata.get(\"labels\", [])\n",
    "            if \"size:\" in label\n",
    "        ),\n",
    "        \"\",\n",
    "    )\n",
    "    new_metadata = {\n",
    "        \"state\": metadata[\"state\"],\n",
    "        \"year\": year,\n",
    "        \"month\": month,\n",
    "        \"day\": day,\n",
    "        \"assignee\": assignee,\n",
    "        \"size\": size,\n",
    "        \"index_id\": doc.id_,\n",
    "    }\n",
    "\n",
    "    # now extract out summary\n",
    "    summary_index = SummaryIndex.from_documents([doc])\n",
    "    query_str = \"Give a one-sentence concise summary of this issue.\"\n",
    "    # Pass the LLM directly: `ServiceContext` is deprecated in llama-index-core.\n",
    "    query_engine = summary_index.as_query_engine(llm=OpenAI(model=\"gpt-3.5-turbo\"))\n",
    "    # Await the async API; the blocking `query()` call would stall the event\n",
    "    # loop and keep the concurrent jobs from overlapping.\n",
    "    summary_txt = str(await query_engine.aquery(query_str))\n",
    "\n",
    "    return Document(text=summary_txt, metadata=new_metadata)\n",
    "\n",
    "\n",
    "async def aprocess_docs(docs, workers: int = 5):\n",
    "    \"\"\"Run `aprocess_doc` over every document in `docs`.\n",
    "\n",
    "    Args:\n",
    "        docs: Issue documents, each passed to `aprocess_doc`.\n",
    "        workers: Forwarded to `run_jobs(workers=...)`; defaults to 5.\n",
    "\n",
    "    Returns:\n",
    "        The results returned by `run_jobs` for the submitted coroutines.\n",
    "    \"\"\"\n",
    "    tasks = [aprocess_doc(doc) for doc in docs]\n",
    "    return await run_jobs(tasks, show_progress=True, workers=workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_docs = await aprocess_docs(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_docs[5].metadata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup Weaviate Indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.vector_stores.weaviate import WeaviateVectorStore\n",
    "from llama_index.core import StorageContext\n",
    "from llama_index.core import VectorStoreIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import weaviate\n",
    "\n",
    "# cloud: credentials are taken from the environment when available so they do\n",
    "# not have to be pasted into the notebook; the placeholders are the fallback.\n",
    "auth_config = weaviate.AuthApiKey(api_key=os.environ.get(\"WEAVIATE_API_KEY\", \"\"))\n",
    "client = weaviate.Client(\n",
    "    os.environ.get(\"WEAVIATE_URL\", \"https://<weaviate-cluster>.weaviate.network\"),\n",
    "    auth_client_secret=auth_config,\n",
    ")\n",
    "\n",
    "doc_metadata_index_name = \"LlamaIndex_auto\"\n",
    "doc_chunks_index_name = \"LlamaIndex_AutoDoc\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optional: delete schema\n",
    "client.schema.delete_class(doc_metadata_index_name)\n",
    "client.schema.delete_class(doc_chunks_index_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup Metadata Schema\n",
    "\n",
    "The auto-retriever requires this schema: it is inserted into the LLM prompt so the model knows which metadata fields it can filter on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.vector_stores import MetadataInfo, VectorStoreInfo\n",
    "\n",
    "\n",
    "# One row per metadata field: (name, type, description).\n",
    "_metadata_fields = [\n",
    "    (\"state\", \"string\", \"Whether the issue is `open` or `closed`\"),\n",
    "    (\"year\", \"integer\", \"The year issue was created\"),\n",
    "    (\"month\", \"integer\", \"The month issue was created\"),\n",
    "    (\"day\", \"integer\", \"The day issue was created\"),\n",
    "    (\"assignee\", \"string\", \"The assignee of the ticket\"),\n",
    "    (\"size\", \"string\", \"How big the issue is (XS, S, M, L, XL, XXL)\"),\n",
    "]\n",
    "\n",
    "vector_store_info = VectorStoreInfo(\n",
    "    content_info=\"Github Issues\",\n",
    "    metadata_info=[\n",
    "        MetadataInfo(name=field_name, description=field_desc, type=field_type)\n",
    "        for field_name, field_type, field_desc in _metadata_fields\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download LlamaPack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.llama_pack import download_llama_pack\n",
    "\n",
    "MultiDocAutoRetrieverPack = download_llama_pack(\n",
    "    \"MultiDocAutoRetrieverPack\", \"./multidoc_autoretriever_pack\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pack = MultiDocAutoRetrieverPack(\n",
    "    client,\n",
    "    doc_metadata_index_name,\n",
    "    doc_chunks_index_name,\n",
    "    new_docs,\n",
    "    docs,\n",
    "    vector_store_info,\n",
    "    auto_retriever_kwargs={\n",
    "        \"verbose\": True,\n",
    "        \"similarity_top_k\": 2,\n",
    "        \"empty_query_top_k\": 10,\n",
    "    },\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run LlamaPack\n",
    "\n",
    "Now let's try the LlamaPack on some queries! "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = pack.run(\"Tell me about some issues on 12/11\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = pack.run(\"Tell me about some open issues related to agents\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Retriever-only\n",
    "\n",
    "We can also get the retriever module and just run that."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = pack.get_modules()[\"recursive_retriever\"]\n",
    "nodes = retriever.retrieve(\"Tell me about some open issues related to agents\")\n",
    "print(f\"Number of source nodes: {len(nodes)}\")\n",
    "nodes[0].node.metadata"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_hub",
   "language": "python",
   "name": "llama_hub"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
